import numpy as np
import h5py
import matplotlib.pyplot as plt
import os
import bluepysnap


_DEFAULT_UP_FILE = "../../simulations/spike_inputs/CT/CT_UPstate_90percent_subset_3500-4000ms_5Hz.h5"


def _scatter_projections(axes, df, **scatter_kwargs):
    """Scatter ``df`` on the x-z plane (``axes[0]``) and the x-y plane (``axes[1]``).

    ``df`` must expose ``x``, ``y`` and ``z`` columns; ``scatter_kwargs`` are
    forwarded unchanged to ``DataFrame.plot``. The x-z panel is drawn first,
    then the x-y panel.
    """
    for ax, vertical in zip(axes, ("z", "y")):
        df.plot(x="x", y=vertical, kind="scatter", ax=ax, **scatter_kwargs)


def plot_microcircuit_with_postsyn_cells(
    sim,
    *,
    up_file=_DEFAULT_UP_FILE,
    ct_index=2000,
    output_path="column.pdf",
    hide_axes=False,
    close=True,
):
    """Plot the mc2 VPL/Rt column and highlight thalamic targets of one CT source.

    Two panels are drawn: the x-z projection (left) and the x-y projection
    (right) of the ``mc2;VPL`` and ``mc2;Rt`` node sets. The postsynaptic
    thalamic cells of a single CorticoThalamic source node are overlaid in
    red. Only postsynaptic cells that are ``VPL_TC`` in ``mc2;VPL`` or
    ``Rt_RC`` in ``mc2;Rt`` are kept.

    Args:
        sim: A ``bluepysnap.Simulation`` whose circuit has a
            ``thalamus_neurons`` node population and a
            ``CorticoThalamic_projections__thalamus_neurons__chemical`` edge
            population.
        up_file: HDF5 spike-input file. The CT source candidates are read from
            its ``spikes/CorticoThalamic_projections/node_ids`` dataset.
        ct_index: Position in that ``node_ids`` dataset of the CT source node
            whose targets are highlighted.
        output_path: Where the figure is saved (format inferred by matplotlib).
        hide_axes: If True, hide all spines, ticks, tick labels and axis labels.
        close: If True (default), close the figure after saving so repeated
            calls do not accumulate open matplotlib figures.

    Returns:
        The matplotlib ``Figure`` that was saved.

    Raises:
        IndexError: If ``ct_index`` is outside the ``node_ids`` dataset.
    """
    node_population = sim.circuit.nodes["thalamus_neurons"]
    df_Rt = node_population.positions(sim.circuit.node_sets["mc2;Rt"])
    df_VPL = node_population.positions(sim.circuit.node_sets["mc2;VPL"])

    # Candidate presynaptic CT node ids, as stored in the spike-input file.
    with h5py.File(up_file, "r") as h5f:
        ct_gids_subset = h5f["spikes"]["CorticoThalamic_projections"]["node_ids"][()]

    # Efferent nodes are the nodes receiving an incoming edge from the chosen
    # source node.
    # NOTE(review): an earlier comment said CT gid 894 is at index 2001 of this
    # dataset, while the code has always used index 2000; confirm which is intended.
    ct_edges = sim.circuit.edges["CorticoThalamic_projections__thalamus_neurons__chemical"]
    thal_postsyn_nodes = ct_edges.efferent_nodes(ct_gids_subset[ct_index], unique=True)

    # Restrict the targets to the VPL_TC / Rt_RC cells of the mc2 column.
    mc2_col_thal_gids = np.concatenate(
        (
            node_population.ids({"mtype": "VPL_TC", "region": "mc2;VPL"}),
            node_population.ids({"mtype": "Rt_RC", "region": "mc2;Rt"}),
        )
    )
    thal_postsyn_nodes = np.intersect1d(mc2_col_thal_gids, thal_postsyn_nodes)
    df_postsyn_thal_cells = node_population.get(thal_postsyn_nodes, properties=["x", "y", "z"])

    fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(8, 4))
    point_size = 2
    background_style = dict(s=point_size, alpha=0.5, marker=".", rasterized=True)
    # Draw order matters: the legend below labels the first two artists on
    # axes[0], which must be VPL then Rt.
    _scatter_projections(axes, df_VPL, color="orange", **background_style)
    _scatter_projections(axes, df_Rt, color="royalblue", **background_style)
    # Highlighted postsynaptic cells are drawn last so they sit on top.
    _scatter_projections(axes, df_postsyn_thal_cells, color="r", s=10, marker=".")

    axes[0].legend(["VPL", "TRN"], numpoints=1)

    # Equal data scaling on both panels; axes[0] is additionally forced square.
    # NOTE(review): with sharex=True, "square" on axes[0] also changes the
    # shared x-limits of axes[1]; confirm this "square"/"equal" mix is intended.
    axes[0].axis("square")
    axes[1].axis("equal")

    if hide_axes:
        for ax in axes:
            for side in ("top", "right", "bottom", "left"):
                ax.spines[side].set_visible(False)
            ax.tick_params(
                left=False,
                labelleft=False,
                top=False,
                labeltop=False,
                right=False,
                labelright=False,
                bottom=False,
                labelbottom=False,
            )
            ax.set_xlabel("")
            ax.set_ylabel("")

    # A single tight_layout is enough: the former subplots_adjust(hspace=...)
    # had no effect on a one-row grid and was overridden by a second tight_layout.
    fig.tight_layout()
    fig.savefig(output_path)
    if close:
        plt.close(fig)
    return fig


if __name__ == "__main__":
    # Script entry point: load the fixed simulation and render the column figure.
    simulation_config = "../../simulations/08_08_Cav_variants/variant1_Ecel/variant1_Ecel_SUPER_LONG_90percent_CT_up_down/simulation_config.json"
    simulation = bluepysnap.Simulation(simulation_config)
    plot_microcircuit_with_postsyn_cells(simulation)
